import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

torch.manual_seed(43)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def target_function(x):
	"""Return the square of ``x`` (the curve the network is fitted to).

	Works for anything that supports the power operator, including
	tensors, where the square is taken element-wise.
	"""
	squared = pow(x, 2)
	return squared
def own_sin(x):
	"""Return ``sin(x) + cos(x)`` for a tensor ``x``.

	Despite the name, this is the sum of sine and cosine, not a plain
	sine. The result has the same shape as ``x``.
	"""
	sine_part = torch.sin(x)
	cosine_part = torch.cos(x)
	return sine_part + cosine_part
# Pretrained Multi-Layer Perceptron (MLP) model
class PretrainedMLP(nn.Module):
	"""Fully connected network mapping 20 inputs to 20 outputs.

	The stack is Linear(20, 100) -> SiLU -> Linear(100, 200) -> SiLU ->
	Linear(200, 100) -> SiLU -> Linear(100, 20), stored as ``self.layers``
	so that state-dict keys are ``layers.0.*``, ``layers.2.*``, ...
	"""

	def __init__(self):
		super().__init__()
		# Consecutive pairs of widths define the Linear layers, in order.
		widths = (20, 100, 200, 100, 20)
		stack = []
		for fan_in, fan_out in zip(widths[:-1], widths[1:]):
			stack.append(nn.Linear(fan_in, fan_out))
			stack.append(nn.SiLU())
		# The output layer is not followed by an activation.
		stack.pop()
		self.layers = nn.Sequential(*stack)

	def forward(self, x):
		"""Run ``x`` (last dimension of size 20) through the layer stack."""
		output = self.layers(x)
		return output

# New Neural Network with trainable parameters
class SuperMLP(nn.Module):
	"""Three stacked mixing layers, each owning 19 learnable values.

	Every layer stacks its own 19 parameters on top of a single-row
	signal, applies ``own_sin`` to the result and averages over the
	signal's columns, which yields a ``(1, 20)`` tensor that feeds the
	next layer. The output of the third layer is the model output.
	"""

	def __init__(self):
		super(SuperMLP, self).__init__()
		# One (19, 1) parameter column per layer. They are created in this
		# order so that seeded initialisation stays reproducible.
		self.trainable_params1 = nn.Parameter(torch.randn(19, 1))
		self.trainable_params2 = nn.Parameter(torch.randn(19, 1))
		self.trainable_params3 = nn.Parameter(torch.randn(19, 1))

	def _mix_layer(self, params, signal):
		"""Combine one layer's parameters with a ``(1, width)`` signal.

		Returns a ``(1, 20)`` tensor: 19 entries derived from ``params``
		and one entry derived from ``signal``.
		"""
		width = signal.shape[1]
		# Repeat the parameter column across the signal's columns; follow the
		# signal's device rather than relying on a module-level global.
		expanded = params.expand(19, width).to(signal.device)
		stacked = torch.cat([expanded, signal], dim=0)  # (20, width)
		# Transpose to (width, 20), then average over the width dimension.
		return own_sin(stacked.T).mean(dim=0).reshape(1, 20)

	def forward(self, x):
		"""Map a ``(1, batch)`` input to a ``(1, 20)`` output.

		Each of the three layers uses its *own* parameter tensor, so all
		57 parameters take part in the computation and receive gradients.
		"""
		hidden = x
		layer_params = (
			self.trainable_params1,
			self.trainable_params2,
			self.trainable_params3,
		)
		for params in layer_params:
			hidden = self._mix_layer(params, hidden)
		return hidden


if __name__ == '__main__':
	# Load the pretrained model.
	# NOTE(review): nothing below uses trained_pnn; the load is kept so the
	# script still fails early when the checkpoint is missing or mismatched.
	trained_pnn = PretrainedMLP().to(device)
	trained_pnn.load_state_dict(torch.load('trained_PNN.pth',weights_only=True, map_location=device))
	trained_pnn.eval()

	# Initialize the new neural network
	new_nn = SuperMLP().to(device)

	# Print the total number of parameters in the new model
	total_params = sum(p.numel() for p in new_nn.parameters())
	print(f"Total parameters: {total_params}")

	# Define the optimizer
	optimizer = optim.Adam(new_nn.parameters(), lr=0.0001)

	n = 20  # Number of training points

	def weights_init(m):
		"""Xavier-initialise the weight matrix of every nn.Linear module.

		SuperMLP holds bare nn.Parameter tensors and no nn.Linear layers,
		so applying this to it currently changes nothing.
		"""
		if isinstance(m, nn.Linear):
			# The in-place initialiser needs the tensor to fill.
			nn.init.xavier_uniform_(m.weight)

	# Generate training data
	x_train = torch.linspace(-1, 1, n).view(1, n).to(device)
	y_train = target_function(x_train)

	# Evaluation grid (same n points as training). Defined before the loop so
	# the final scatter plot does not depend on a checkpoint having been hit.
	x_test = torch.linspace(-1, 1, n).view(1, n).to(device)
	y_test = target_function(x_test).cpu().detach().numpy()  # Original function values
	x_plot = x_test.cpu().numpy().flatten()

	plt.figure(figsize=(6, 5))

	num_epochs = 20000
	# Number of completed epochs after which the current fit is plotted.
	plot_epochs = (5000, 10000, 15000, 20000)
	new_nn.apply(weights_init)

	# Training loop
	for epoch in range(num_epochs):
		optimizer.zero_grad()
		y_pred = new_nn(x_train)
		loss = ((y_pred - y_train) ** 2).mean()
		loss.backward()
		optimizer.step()

		if epoch % 1000 == 0:
			print(f"Epoch [{epoch}/{num_epochs}], Loss: {loss.item():.4f}")

		# `epoch` is zero-based, so compare the number of *completed* epochs;
		# otherwise the last checkpoint (num_epochs) would never be reached.
		completed = epoch + 1
		if completed in plot_epochs:
			# No gradients are needed for plotting.
			with torch.no_grad():
				y_curve = new_nn(x_test).mean(0).cpu().numpy()  # Predicted function values
			plt.plot(x_plot, y_curve.flatten(), label=f"No. of Epochs: {completed}")

	plt.scatter(x_plot, y_test.flatten(), label="Original Function", linestyle='dashed',color='red')
	plt.xlabel("x")
	plt.ylabel("y")
	plt.title("Comparison of Original and Predicted Functions")
	plt.legend(loc='upper center')
	print("Training Complete!")
	plt.show()
